Feat (llm/learned_round): fast block update #1110
src/brevitas_examples/common/learned_round/learned_round_optimizer.py
LGTM, I'd open an issue to refactor and rely on save_inputs_output as much as possible, to prevent duplicating the block forward code.
@@ -602,26 +603,28 @@ def apply_learned_round(

# Initialize cache to store partial inputs and outputs for each block
cache.initialize_cache()

floating_point_datasets = []
floating_point_datasets is no longer used after the changes, right?
Reason for this PR
The inter-block update in learned round can be very slow for large models, because it goes through the entire model again to refresh the data for each block.
Changes Made in this PR
We assume that the blocks are sequential, so the output of each block is the input to the next one.
We also assume that the kwargs passed to the blocks do not change from block to block (typical in LLMs).
Under these assumptions, the update can run two block forward passes instead of going through the entire model twice.
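The idea above can be illustrated with a toy sketch. This is not the Brevitas implementation: `make_block`, `full_forward`, `fast_block_update`, and the `scale` kwarg are hypothetical names used only for illustration, blocks are plain Python callables on floats, and the sketch assumes the two forwards are one per activation stream (floating-point and quantized). It shows that, when blocks are sequential and the kwargs are constant, advancing the cached activations by one block gives the same result as rerunning the model from the raw samples.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical stand-in for a model block: maps one activation to the next.
Block = Callable[..., float]


def make_block(weight: float, quantized: bool) -> Block:
    """Build a toy block; the 'quantized' variant rounds its output."""
    def block(x: float, scale: float = 1.0) -> float:
        y = weight * x * scale
        return float(round(y)) if quantized else y
    return block


def full_forward(blocks: Sequence[Block], samples: List[float], **kwargs) -> List[float]:
    """Slow path: rerun every block in `blocks` starting from the raw samples."""
    outs = []
    for x in samples:
        for block in blocks:
            x = block(x, **kwargs)
        outs.append(x)
    return outs


def fast_block_update(
    fp_block: Block,
    quant_block: Block,
    fp_inputs: List[float],
    quant_inputs: List[float],
    **kwargs,
) -> Tuple[List[float], List[float]]:
    """Fast path: two block forwards produce the next block's cached inputs.

    Assumes sequential blocks (output of block i is input of block i + 1)
    and kwargs that are identical for every block.
    """
    next_fp = [fp_block(x, **kwargs) for x in fp_inputs]
    next_quant = [quant_block(x, **kwargs) for x in quant_inputs]
    return next_fp, next_quant


weights = [1.5, 0.5, 2.0]
fp_blocks = [make_block(w, quantized=False) for w in weights]
quant_blocks = [make_block(w, quantized=True) for w in weights]
samples = [1.0, 2.0, 3.0]
kwargs = {"scale": 1.0}  # constant across blocks, per the assumption above

fp_cache, quant_cache = list(samples), list(samples)
for i in range(len(weights)):
    fp_cache, quant_cache = fast_block_update(
        fp_blocks[i], quant_blocks[i], fp_cache, quant_cache, **kwargs)
    # The fast update matches rerunning all blocks up to and including block i.
    assert fp_cache == full_forward(fp_blocks[:i + 1], samples, **kwargs)
    assert quant_cache == full_forward(quant_blocks[:i + 1], samples, **kwargs)
```

In this sketch the slow path costs one pass through every preceding block per update, while the fast path costs one forward per stream through a single block. The equivalence breaks as soon as a block consumes something other than the previous block's output, or when the kwargs differ between blocks.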
Testing Summary
NA
Risk Highlight
The limitations described above apply. The flag should be left set to False unless the user knows that the model satisfies these assumptions. This could be improved in the future.
Checklist
dev branch.